blob: bacc76a69f15d962786ed310a43951e718b19f62 [file] [log] [blame]
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <Python.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
// Python 3 has no PyString type. This shim accepts either a str object
// (yielding its UTF-8 encoding) or a bytes object. It evaluates to 0 on
// success and -1 on failure, mirroring PyBytes_AsStringAndSize.
// PyUnicode_AsUTF8AndSize returns `const char*` on newer Python versions, so
// the result is const_cast to fit the `char**` out-parameter; callers must
// treat the returned buffer as read-only.
#define PyString_AsStringAndSize(ob, charpp, sizep) \
  (PyUnicode_Check(ob) ? \
   ((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) \
        == NULL ? -1 : 0) : \
   PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
namespace protobuf {
namespace python {
namespace message_factory {
// Allocates a new MessageFactory instance of `type` bound to `pool`.
//
// The instance gets its own DynamicMessageFactory (delegating to the generated
// factory) and an empty descriptor -> class map. `pool` is stored as a
// borrowed pointer; no reference is taken.
// Returns NULL if the Python object allocation fails.
PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
  PyObject* const allocated = PyType_GenericAlloc(type, 0);
  if (allocated == NULL) {
    return NULL;
  }
  PyMessageFactory* const factory =
      reinterpret_cast<PyMessageFactory*>(allocated);

  DynamicMessageFactory* const dynamic_factory = new DynamicMessageFactory();
  // This option might be the default some day.
  dynamic_factory->SetDelegateToGeneratedFactory(true);
  factory->message_factory = dynamic_factory;

  factory->pool = pool;
  // TODO(amauryfa): When the MessageFactory is not created from the
  // DescriptorPool this reference should be owned, not borrowed.
  // Py_INCREF(pool);

  factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
  return factory;
}
// tp_new for MessageFactory: MessageFactory(pool=None).
//
// If `pool` is omitted or None, a fresh DescriptorPool is created for the
// factory. Otherwise `pool` must be a DescriptorPool instance, or TypeError
// is raised.
// Returns a new reference, or NULL with a Python exception set.
PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  // The const_cast avoids the string-literal -> char* conversion that C++11
  // forbids, while still matching PyArg_ParseTupleAndKeywords' char**.
  static char* kwlist[] = {const_cast<char*>("pool"), 0};
  PyObject* pool = NULL;
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
    return NULL;
  }
  ScopedPyObjectPtr owned_pool;
  if (pool == NULL || pool == Py_None) {
    owned_pool.reset(PyObject_CallFunction(
        reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
    if (owned_pool == NULL) {
      return NULL;
    }
    pool = owned_pool.get();
  } else if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
    PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
                 pool->ob_type->tp_name);
    return NULL;
  }

  PyMessageFactory* factory =
      NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool));
  if (factory != NULL && owned_pool.get() != NULL) {
    // factory->pool is only a borrowed pointer, and Dealloc never releases
    // it. Our reference is the only one to the pool created above. Dropping
    // it here would most likely free the pool on return and leave
    // factory->pool dangling. Keep the pool alive instead: a leak is
    // preferable to a use-after-free.
    // TODO: make the factory own its pool (with GC support) and remove this.
    owned_pool.release();
  }
  return reinterpret_cast<PyObject*>(factory);
}
// tp_dealloc for MessageFactory.
//
// Drops the reference held for every registered message class, deletes the
// descriptor -> class map and the DynamicMessageFactory, then frees the
// object through its type's tp_free. The pool pointer is borrowed and is
// therefore not released here.
static void Dealloc(PyObject* pself) {
  PyMessageFactory* const factory = reinterpret_cast<PyMessageFactory*>(pself);
  // TODO(amauryfa): When the MessageFactory is not created from the
  // DescriptorPool this reference should be owned, not borrowed.
  // Py_CLEAR(self->pool);
  PyMessageFactory::ClassesByMessageMap* const classes =
      factory->classes_by_descriptor;
  PyMessageFactory::ClassesByMessageMap::iterator entry = classes->begin();
  while (entry != classes->end()) {
    Py_DECREF(entry->second);
    ++entry;
  }
  delete classes;
  delete factory->message_factory;
  Py_TYPE(factory)->tp_free(pself);
}
// Add a message class to our database.
//
// Records `message_class` as the class for `message_descriptor`. The map holds
// its own reference to every class it stores. A new reference is taken here.
// If the descriptor already had a class, the map's reference to that previous
// class is released. Always returns 0.
int RegisterMessageClass(PyMessageFactory* self,
                         const Descriptor* message_descriptor,
                         CMessageClass* message_class) {
  // Take the map's reference first, so that re-registering the very same
  // class can never drop its refcount to zero below.
  Py_INCREF(message_class);
  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
  std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
      std::make_pair(message_descriptor, message_class));
  if (!ret.second) {
    // Update case: install the new class before releasing the old one.
    // Py_DECREF may deallocate the previous class and run arbitrary code,
    // and the map must not still point at it while that happens.
    CMessageClass* previous = ret.first->second;
    ret.first->second = message_class;
    Py_DECREF(previous);
  }
  return 0;
}
// Returns the message class for `descriptor`, creating it if needed.
//
// This is the same implementation as MessageFactory.GetPrototype(). When a
// new class is created, this also creates:
//   - classes for every message-typed field of `descriptor`, and
//   - the classes extended by each extension declared inside `descriptor`,
//     registering those extensions on them.
// Returns a new reference, or NULL with a Python exception set.
//
// NOTE(review): termination for self-referential message types relies on
// constructing the class (via CMessageClass_Type) registering it in
// classes_by_descriptor before the field recursion below runs. That
// presumably happens through RegisterMessageClass; confirm in message.cc.
CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
                                       const Descriptor* descriptor) {
  // Do not create a MessageClass that already exists.
  PyMessageFactory::ClassesByMessageMap::iterator it =
      self->classes_by_descriptor->find(descriptor);
  if (it != self->classes_by_descriptor->end()) {
    Py_INCREF(it->second);
    return it->second;
  }
  ScopedPyObjectPtr py_descriptor(
      PyMessageDescriptor_FromDescriptor(descriptor));
  if (py_descriptor == NULL) {
    return NULL;
  }
  // Create a new message class by calling the metaclass as
  //   CMessageClass(name, (), {"DESCRIPTOR": ..., "__module__": None,
  //                            "message_factory": self}).
  // `self` is cast explicitly because the "O" format expects a PyObject*.
  ScopedPyObjectPtr args(Py_BuildValue(
      "s(){sOsOsO}", descriptor->name().c_str(),
      "DESCRIPTOR", py_descriptor.get(),
      "__module__", Py_None,
      "message_factory", reinterpret_cast<PyObject*>(self)));
  if (args == NULL) {
    return NULL;
  }
  ScopedPyObjectPtr message_class(PyObject_CallObject(
      reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get()));
  if (message_class == NULL) {
    return NULL;
  }
  // Create messages class for the messages used by the fields, and registers
  // all extensions for these messages during the recursion.
  for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
    const Descriptor* sub_descriptor =
        descriptor->field(field_idx)->message_type();
    // It is NULL if the field type is not a message.
    if (sub_descriptor != NULL) {
      CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
      if (result == NULL) {
        return NULL;
      }
      Py_DECREF(result);
    }
  }
  // Register extensions defined in this message.
  for (int ext_idx = 0; ext_idx < descriptor->extension_count(); ext_idx++) {
    const FieldDescriptor* extension = descriptor->extension(ext_idx);
    CMessageClass* extended_class =
        GetOrCreateMessageClass(self, extension->containing_type());
    // Check before use: calling AsPyObject() through a NULL pointer is
    // undefined behavior.
    if (extended_class == NULL) {
      return NULL;
    }
    ScopedPyObjectPtr py_extended_class(extended_class->AsPyObject());
    ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
    if (py_extension == NULL) {
      return NULL;
    }
    ScopedPyObjectPtr result(cmessage::RegisterExtension(
        py_extended_class.get(), py_extension.get()));
    if (result == NULL) {
      return NULL;
    }
  }
  return reinterpret_cast<CMessageClass*>(message_class.release());
}
// Retrieve the message class added to our database.
//
// Returns a borrowed reference (the refcount is not incremented) to the
// class registered for `message_descriptor`. If none is registered, raises
// TypeError and returns NULL.
CMessageClass* GetMessageClass(PyMessageFactory* self,
                               const Descriptor* message_descriptor) {
  PyMessageFactory::ClassesByMessageMap* const classes =
      self->classes_by_descriptor;
  PyMessageFactory::ClassesByMessageMap::iterator found =
      classes->find(message_descriptor);
  if (found != classes->end()) {
    return found->second;
  }
  PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
               message_descriptor->full_name().c_str());
  return NULL;
}
// MessageFactory exposes no Python-level methods; the table only holds the
// terminating sentinel entry.
static PyMethodDef Methods[] = {
    {NULL, NULL, 0, NULL}  // sentinel
};
// Getter for the "pool" attribute. Returns a new reference to the factory's
// DescriptorPool. `closure` is unused.
static PyObject* GetPool(PyMessageFactory* self, void* closure) {
  PyObject* const pool = reinterpret_cast<PyObject*>(self->pool);
  Py_INCREF(pool);
  return pool;
}
// Attribute table: a single read-only "pool" attribute (no setter).
// Entry layout: name, getter, setter, doc, closure.
static PyGetSetDef Getters[] = {
    {"pool", reinterpret_cast<getter>(GetPool), NULL, "DescriptorPool", NULL},
    {NULL}  // sentinel
};
} // namespace message_factory
// Python type object for MessageFactory, named FULL_MODULE_NAME ".MessageFactory".
//
// Instances are created by message_factory::New (tp_new) and destroyed by
// message_factory::Dealloc (tp_dealloc). The type may be subclassed
// (Py_TPFLAGS_BASETYPE). It does not set Py_TPFLAGS_HAVE_GC, so base-type
// instances are not tracked by the cyclic garbage collector. Accordingly,
// tp_traverse/tp_clear are left unset and tp_free is PyObject_Del.
// The only exposed attribute is the read-only "pool" getter.
PyTypeObject PyMessageFactory_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".MessageFactory", // tp_name
sizeof(PyMessageFactory), // tp_basicsize
0, // tp_itemsize
message_factory::Dealloc, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
0, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
0, // tp_as_mapping
0, // tp_hash
0, // tp_call
0, // tp_str
0, // tp_getattro
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
"A static Message Factory", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
message_factory::Methods, // tp_methods
0, // tp_members
message_factory::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
0, // tp_descr_set
0, // tp_dictoffset
0, // tp_init
0, // tp_alloc
message_factory::New, // tp_new
PyObject_Del, // tp_free
};
// Finalizes PyMessageFactory_Type with PyType_Ready.
// Returns true on success, and false (with a Python exception set by
// PyType_Ready) on failure.
bool InitMessageFactory() {
  const bool ready = PyType_Ready(&PyMessageFactory_Type) >= 0;
  return ready;
}
} // namespace python
} // namespace protobuf
} // namespace google