1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
diff --git a/vs_mxnet/vsMXNet.cpp b/vs_mxnet/vsMXNet.cpp
index 6190ff2..f9e1184 100644
--- a/vs_mxnet/vsMXNet.cpp
+++ b/vs_mxnet/vsMXNet.cpp
@@ -3,20 +3,10 @@
#include <algorithm>
#include <vector>
-#include <VapourSynth/VapourSynth.h>
-#include <VapourSynth/VSHelper.h>
+#include <VapourSynth.h>
+#include <VSHelper.h>
-#include "MXDll.h"
-
-#ifdef _MSC_VER
-#if defined (_WINDEF_) && defined(min) && defined(max)
-#undef min
-#undef max
-#endif
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif
-#endif
+#include <mxnet/c_predict_api.h>
// no int8 and uint16
inline int VSFormatToMXDtype(const VSFormat *format)
@@ -81,18 +71,16 @@ std::vector<char> ReadFile(const std::string &file_path)
return buf;
}
-MXNet mx("libmxnet.dll");
-
inline int mxForward(mxnetData * VS_RESTRICT d)
{
int ch = d->vi.format->numPlanes;
auto imageSize = d->patch_h * d->patch_w * ch;
- if (mx.MXPredSetInput(d->hPred, "data", (float *)d->srcBuffer, imageSize) != 0) {
+ if (MXPredSetInput(d->hPred, "data", (float *)d->srcBuffer, imageSize) != 0) {
return 2;
}
- if (mx.MXPredForward(d->hPred) != 0) {
+ if (MXPredForward(d->hPred) != 0) {
return 2;
}
@@ -102,7 +90,7 @@ inline int mxForward(mxnetData * VS_RESTRICT d)
uint32_t shape_len = 0;
// Get Output Result
- if (mx.MXPredGetOutputShape(d->hPred, output_index, &shape, &shape_len) != 0) {
+ if (MXPredGetOutputShape(d->hPred, output_index, &shape, &shape_len) != 0) {
return 2;
}
@@ -113,7 +101,7 @@ inline int mxForward(mxnetData * VS_RESTRICT d)
return 1;
}
- if (mx.MXPredGetOutput(d->hPred, output_index, (float *)d->dstBuffer, outputSize) != 0) {
+ if (MXPredGetOutput(d->hPred, output_index, (float *)d->dstBuffer, outputSize) != 0) {
return 2;
}
@@ -208,7 +196,7 @@ static const VSFrameRef *VS_CC mxGetFrame(int n, int activationReason, void **in
err = "mxnet: input and target shapes do not match";
else if (error == 2) {
err = "mxnet: failed to process: ";
- err += mx.MXGetLastError();
+ err += MXGetLastError();
}
else if (error == 3)
err = "mxnet: not support clip format";
@@ -231,7 +219,7 @@ static void VS_CC mxFree(void *instanceData, VSCore *core, const VSAPI *vsapi)
mxnetData *d = static_cast<mxnetData *>(instanceData);
vsapi->freeNode(d->node);
- mx.MXPredFree(d->hPred);
+ MXPredFree(d->hPred);
vs_aligned_free(d->srcBuffer);
vs_aligned_free(d->dstBuffer);
@@ -439,19 +427,11 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *
d.hPred = nullptr;
- if (!mx.IsInit()) {
- mx.LoadDll(nullptr);
- }
-
- if (!mx.IsInit()) {
- throw std::string{ "Cannot load MXNet. Please check MXNet installation." };
- }
-
const char *arg_dtype_names[] = { "data" };
int arg_dtype[1] = { input_dtype };
// Create Predictor
- if (mx.MXPredCreateEx(
+ if (MXPredCreateEx(
json_data.data(), param_data.data(),
static_cast<int>(param_data.size()),
dev_type, dev_id,
@@ -459,11 +439,11 @@ static void VS_CC mxCreate(const VSMap *in, VSMap *out, void *userData, VSCore *
input_keys, input_shape_indptr, input_shape_data,
1, arg_dtype_names, arg_dtype,
&d.hPred) != 0) {
- throw std::string{ "Create MXNet Predictor failed: "} + mx.MXGetLastError();
+ throw std::string{ "Create MXNet Predictor failed: "} + MXGetLastError();
}
if (d.hPred == nullptr) {
- throw std::string{ "Invalid MXNet Predictor:" } + mx.MXGetLastError() + " Please Try to Upgrade MXNet.";
+ throw std::string{ "Invalid MXNet Predictor:" } + MXGetLastError() + " Please Try to Upgrade MXNet.";
}
} catch (const std::string & error) {
vsapi->setError(out, ("mxnet: " + error).c_str());
|